#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Apr 19 15:54:38 2022

@author: qiguangyao
"""


#%%Lib
import copy
import numpy as np
import pickle as pkl
import matplotlib.pyplot as plt
from scipy import stats
import seaborn as sns 
from scipy import asarray as ar,exp
from scipy.optimize import curve_fit
import math
import pingouin as pg
from sklearn import linear_model
from pylab import cos
import pandas as pd
import random
from statsmodels.formula.api import ols
from statsmodels.stats.anova import anova_lm
from statsmodels.sandbox.stats.multicomp import multipletests # for multiple comparisons correction
from statsmodels.stats.multicomp import pairwise_tukeyhsd
print("__file Output:",__file__)
#%%functions
import scipy.stats
def mean_confidence_interval(data, confidence=0.95):
    """Return the sample mean with its two-sided Student-t confidence bounds.

    Parameters
    ----------
    data : array-like
        Sample values; converted with ``np.array`` and multiplied by 1.0.
    confidence : float, optional
        Coverage of the interval (default 0.95).

    Returns
    -------
    tuple
        ``(mean, lower, upper)``. The bounds are ``mean -/+ h`` where ``h``
        is the standard error of the mean times the t quantile at
        ``(1 + confidence) / 2`` with ``len(data) - 1`` degrees of freedom.
    """
    samples = np.array(data) * 1.0
    dof = len(samples) - 1
    center = np.mean(samples)
    std_err = scipy.stats.sem(samples)
    # Two-sided interval: (1 - confidence) / 2 of the mass sits in each tail.
    t_crit = scipy.stats.t.ppf((1 + confidence) / 2., dof)
    half_width = std_err * t_crit
    return center, center - half_width, center + half_width

def adjust_spines(ax, spines):
    """Keep only the requested spines of ``ax`` and move them outward.

    Parameters
    ----------
    ax : object exposing ``spines``, ``xaxis`` and ``yaxis``
        Typically a matplotlib Axes.
    spines : container of str
        Names of the spines to keep (e.g. ``['left', 'bottom']``).

    Spines named in ``spines`` are offset outward by 10 points; every other
    spine gets the colour ``'none'``. Ticks are kept on the left/bottom only
    when that spine is requested, otherwise the axis ticks are cleared.
    """
    for name, spine in ax.spines.items():
        if name not in spines:
            # Hidden spine: not drawn at all.
            spine.set_color('none')
            continue
        spine.set_position(('outward', 10))

    # An axis without its spine gets no ticks.
    if 'left' not in spines:
        ax.yaxis.set_ticks([])
    else:
        ax.yaxis.set_ticks_position('left')

    if 'bottom' not in spines:
        ax.xaxis.set_ticks([])
    else:
        ax.xaxis.set_ticks_position('bottom')
        
def gaus(x,a,x0,sigma):
    """Evaluate a normalised Gaussian density scaled by ``a``.

    Parameters
    ----------
    x : float or ndarray
        Point(s) at which to evaluate.
    a : float
        Scale factor applied to the unit-area Gaussian.
    x0 : float
        Centre of the Gaussian.
    sigma : float
        Standard deviation (must be non-zero).

    Returns
    -------
    float or ndarray
        ``a / (sigma * sqrt(2*pi)) * exp(-(x - x0)**2 / (2 * sigma**2))``.

    Notes
    -----
    The normalisation is written with explicit parentheses: the unbracketed
    form ``1/sigma*np.sqrt(2*np.pi)`` evaluates to ``sqrt(2*pi)/sigma``,
    which is 2*pi times too large. ``np.exp`` is used so the function does
    not depend on ``exp`` being re-exported from the top-level ``scipy``
    namespace.
    """
    norm = 1.0 / (sigma * np.sqrt(2 * np.pi))
    return a * norm * np.exp(-(x - x0) ** 2 / (2 * sigma ** 2))

def gaussian(X, amp, cen, wid):
    """Unnormalised bell curve ``amp * exp(-(X - cen)**2 / wid)``.

    ``amp`` is the peak height reached at ``X == cen``. ``wid`` divides the
    squared offset directly, so it is a squared width rather than a standard
    deviation.
    """
    offset_sq = (X - cen) ** 2
    return amp * exp(-offset_sq / wid)

def getPossionPDF(mu,x):
    """Poisson probability of observing count ``x`` at rate ``mu + 0.01``.

    Parameters
    ----------
    mu : float
        Rate parameter. 0.01 is added before evaluation so that a rate of
        exactly zero does not give a degenerate distribution.
    x : float or int
        Observed count (scalar). It is clamped to the range [0, 170] and
        rounded to the nearest integer with the built-in ``round``.

    Returns
    -------
    float
        ``exp(-rate) * rate**k / k!`` with ``rate = mu + 0.01`` and ``k`` the
        clamped, rounded count; never negative.

    Notes
    -----
    For a positive rate the value is computed in log space with
    ``math.lgamma``. The direct formula overflows in ``rate**k`` for
    moderately large rates (for example mu=100, x=170) even though the
    probability itself is an ordinary small number.
    """
    # Legacy clamp, kept for backward compatibility: 170! is the largest
    # factorial that still fits in a double.
    if x > 170:
        x = 170
    if x < 0:
        x = 0
    # math.factorial/lgamma need a true int; round() on some numeric types
    # (e.g. older numpy scalars) returns a float.
    count = int(round(x))
    rate = mu + 0.01

    if rate > 0:
        log_pmf = count * math.log(rate) - rate - math.lgamma(count + 1)
        return math.exp(log_pmf)

    # A non-positive rate is outside the Poisson domain. Keep the original
    # direct evaluation (with negative results clipped to 0) so such inputs
    # still return a value instead of raising from math.log.
    out = math.exp(-rate) * (rate ** count) / math.factorial(count)
    if out < 0:
        out = 0
    return out

#tuning curve fitting
def vonMisesFunction(x,b,a,u):
    """Rectified cosine tuning curve: ``b + a*cos(x - u)`` floored at zero.

    Parameters
    ----------
    x : float or ndarray
        Angle(s) at which to evaluate; passed to ``cos``, so in radians.
    b : float
        Baseline offset.
    a : float
        Cosine amplitude.
    u : float
        Angle at which the cosine term peaks.

    Returns
    -------
    ndarray
        A new array (0-d for scalar ``x``) in which every negative value has
        been replaced by 0.
    """
    rate = np.array(b + a * cos(x - u))
    # Replace negative predictions by zero.
    below_zero = rate < 0
    rate[below_zero] = 0
    return rate

def getvonMisesParas(x,y):
    """Fit ``vonMisesFunction`` to the points (x, y) by least squares.

    x:hand position
    y:firing rate

    The fit uses ``scipy.optimize.curve_fit`` starting from
    ``[b, a, u] = [1, 0, 1]`` with up to 500000 function evaluations.
    Returns the fitted parameters in the order ``[b, a, u]``; the covariance
    estimate is discarded.
    """
    start = [1, 0, 1]  # initial guess for [b, a, u]
    fitted, _ = curve_fit(vonMisesFunction, x, y, p0=start, maxfev=500000)
    return fitted

def getExpParas(x,y):
    """Fit ``expFunction`` to the points (x, y) by least squares.

    x:hand position
    y:firing rate

    The fit uses ``scipy.optimize.curve_fit`` starting from
    ``[a, b, c] = [1, 0, 1]`` with up to 500000 function evaluations.
    Returns the fitted parameters in the order ``[a, b, c]``; the covariance
    estimate is discarded.
    """
    start = [1, 0, 1]  # initial guess for [a, b, c]
    fitted, _ = curve_fit(expFunction, x, y, p0=start, maxfev=500000)
    return fitted

def expFunction(x, a, b, c):
    """Exponential with offset: ``a * exp(-b * x) + c``.

    ``a`` scales the exponential term, ``b`` is the rate in the exponent
    (positive ``b`` decays as ``x`` grows) and ``c`` is the additive offset.
    """
    decay = np.exp(-b * x)
    return a * decay + c
#%% ------------figure 1 ---------------- 
#load data
# NOTE: unpickling can execute arbitrary code; only load 'fig1Data.pickle'
# when it comes from a trusted source.
# The context manager closes the file handle as soon as the data is read.
with open('fig1Data.pickle','rb') as fig1DataFile:
    fig1Data = pkl.load(fig1DataFile)

#1C
# Unpack the entries of the figure-1 data dictionary into module-level names.
disparity =fig1Data['disparity']
drift =fig1Data['drift']
meanDrif =fig1Data['meanDrif']
drifGrouSimu =fig1Data['drifGrouSimu']
mdrfSimu = fig1Data['mdrfSimu']

#%%fig1G
# Drift against disparity: measured points and their per-disparity mean in
# black, simulated points and their mean in gray, shifted by +3 on the x axis
# so the two sets do not sit on top of each other.
# NOTE(review): the cell is tagged fig1G but the output file name says fig1C -
# confirm which panel this is.
disp = np.unique(disparity)
with plt.style.context('style_paper.mplstyle'):
    f, ax1 = plt.subplots(ncols=1, nrows=1, sharey=True,
                          figsize=[3.54/1.8, 3.54/2])
    colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]

    # Reference lines: identity (drift == disparity) and zero drift.
    ax1.plot(np.linspace(-50, 50, 10), np.linspace(-50, 50, 10),
             color='#f08342')
    ax1.plot(np.linspace(-50, 50, 10), [0] * 10, color=colors[1])
    ax1.set_zorder(0)

    for i, d in enumerate(disp):
        # 100 simulated drifts for this disparity, drawn without replacement.
        ax1.scatter(np.full(100, d) + 3,
                    np.random.choice(drifGrouSimu[d], 100, replace=False),
                    color='gray', s=8, alpha=.3)
        # Measured drifts at this disparity.
        ax1.scatter(disparity[disparity == d], drift[disparity == d],
                    color='k', s=8, alpha=.3)

    # Mean curves for the data and for the simulation.
    ax1.plot(disp, meanDrif, color='k', label='Data')
    ax1.plot(disp + 3, mdrfSimu, color='gray', ls='--', label='CI')
    ax1.legend(loc='upper left', bbox_to_anchor=[0, 1], labelspacing=1)

    ax1.set_xlabel('Disparity (deg)')
    ax1.set_ylabel('Drift (deg)')
    ax1.set_xlim([-52, 52])
    ax1.set_ylim([-52, 52])
    ax1.set_xticks(np.arange(-50, 51, 25))
    ax1.set_yticks(np.arange(-50, 51, 25))
    f.tight_layout()
    fileName = 'fig1C_20200903NDispDrifDataModelCI.pdf'
    # plt.savefig(fileName,dpi = 600)
plt.show()